#!/usr/bin/python

# TODO
# - multiprocessing for splitting
# - distance-based thresholds
# - Zipf law/frequency-based thresholds

import sys
sys.path = ["."]+sys.path
import pdb
from pdb import pm

import re
import random as pyrand
from pylab import *
import tables
import ocrolib; reload(ocrolib)
from collections import Counter
from ocrolib import patrec; reload(patrec)
from scipy.spatial import distance
from ocrolib.patrec import Dataset,cshow,showgrid
from ocrolib.ligatures import lig



import argparse
# Command-line interface.  Options are registered in the order in which they
# should appear in the --help output.
parser = argparse.ArgumentParser(
    description="Compute hierarchical tree vector quantizers.",
    epilog="You must supply a data file (-d data.h5)")

# Display modes: each takes the path of a previously saved model.
parser.add_argument(
    '--show', default=None,
    help="display the first two levels of a splitter")
parser.add_argument(
    '--show1', default=None,
    help="display the first level of a splitter")

# Input / output.
parser.add_argument(
    '-N', '--maxtrain', type=int, default=10000000000,
    help="max # of training samples")
parser.add_argument(
    '-d', '--data', default=None, required=1,
    help="data file")
parser.add_argument(
    '-o', '--output', default="splitter.smodel",
    help="output file, default: %(default)s")
parser.add_argument(
    '-q', '--quiet', action="store_true",
    help="output less stuff")

# Splitter parameters.  When --vq is given, the script later overrides
# maxsplit, targetsize and maxdepth.
parser.add_argument(
    '--vq', type=int, default=None,
    help="set up parameters suitable for simple vector quantization")
parser.add_argument(
    '--pca', type=float, default=0.95,
    help="PCA parameters, <1=fraction, >=1=number of components, default: %(default)s")
parser.add_argument(
    '--maxsplit', type=int, default=100,
    help="max # vectors at each level, default: %(default)s")
parser.add_argument(
    '--maxdepth', type=int, default=3,
    help="maximum tree depth, default: %(default)s")
parser.add_argument(
    '--splitsize', type=int, default=5000,
    help="size above which a node is split, default: %(default)s")
parser.add_argument(
    '--targetsize', type=int, default=500,
    help="target size for each subtree to determine actual split size, default: %(default)s")

# Regular expression matched against each class label; the default matches
# space, underscore and the control characters \000-\037.
parser.add_argument(
    '--exclude', default=r'[ _\000-\037]',
    help="classes to exclude from splitting, default: %(default)s")

# NOTE(review): args.parallel is not read anywhere in the visible part of
# this script -- confirm whether it is still needed.
parser.add_argument(
    '-Q', '--parallel', type=int, default=0,
    help="number of CPUs used for training")

args = parser.parse_args()



if args.show1 is not None:
    print "loading model"
    import cPickle
    with open(args.show1) as stream:
        sc = cPickle.load(stream)
    assert sc.splitter is not None
    print "got model",sc,len(sc.splitter.centers())
    centers = sc.splitter.centers()
    h,w = 32,32
    assert h*w==centers[0].size
    nc = int(0.5+sqrt(len(centers)))
    nr = int((len(centers)+nc-1)/nc)
    output = zeros(((h+1)*nr,(w+1)*nc))
    def p(i,j,image):
        image = image.reshape(h,w)
        output[i*(h+1):i*(h+1)+h,j*(w+1):j*(w+1)+w] = image
    for i,v in enumerate(centers):
        #print nr,nc,i
        p(i//nc,i%nc,v.reshape(h,w))
    imshow(output)
    gray()
    show()
    sys.exit(0)
    

if args.show is not None:
    print "loading model"
    import cPickle
    with open(args.show) as stream:
        sc = cPickle.load(stream)
    assert sc.splitter is not None
    print "got model",sc,len(sc.splitter.centers())
    centers = sc.splitter.centers()
    h,w = 32,32
    assert h*w==centers[0].size
    nc = amax([(len(sc.subs[i].subs) if sc.subs[i] is not None else 0) for i,_ in enumerate(sc.subs)])
    nc += 2
    nc = minimum(100,nc)
    nr = len(centers)
    output = zeros(((h+1)*nr,(w+1)*nc))
    def p(i,j,image):
        image = image.reshape(h,w)
        output[i*(h+1):i*(h+1)+h,j*(w+1):j*(w+1)+w] = image
    for i in range(nr):
        p(i,0,centers[i])
        if sc.subs[i] is not None:
            subcenters = sc.subs[i].splitter.centers()
            for j,v in enumerate(subcenters[:nc-2]):
                im = v.reshape(h,v)
                p(i,j+2,v.reshape(h,v))
    imshow(output)
    gray()
    show()
    sys.exit(0)
                


# Without a data file there is nothing to train on: show the usage and stop.
# NOTE(review): -d/--data is declared as required in the parser, so this
# branch looks unreachable -- confirm before removing.
if args.data is None:
    parser.print_help()
    sys.exit(0)

# --vq K overrides the tree parameters: at most K vectors per level,
# a target size of 1 and a maximum depth of 0.
if args.vq is not None:
    args.maxsplit,args.targetsize,args.maxdepth = args.vq,1,0

print "loading dataset"
h5 = tables.openFile(args.data,"r")
N = min(args.maxtrain,len(h5.root.classes))
classes = h5.root.classes[:N]
cclasses = [lig.chr(c) for c in classes]
subset = [i for i,c in enumerate(cclasses) if c!="" and not re.search(args.exclude,c)]
cclasses = [cclasses[i] for i in subset]
patches = Dataset(h5.root.patches,subset=subset)

print "got",len(subset),"samples out of",N
assert len(subset)<=args.maxtrain
counter = Counter(cclasses)
print "# classes",len(counter.keys())
print "most common",
for k,v in counter.most_common(10): print k,v,"/",
print "..."



print "starting training"
sc = patrec.HierarchicalSplitter(d=args.pca,maxsplit=args.maxsplit,
                                 splitsize=args.splitsize,targetsize=args.targetsize,
                                 maxdepth=args.maxdepth,quiet=args.quiet)
sc.sizemode = (h5.getNodeAttr("/","sizemode") or "perchar")
sc.fit(patches)

print "writing"
import cPickle
with open(args.output,"w") as stream:
    cPickle.dump(sc,stream,2)

h5.close()
del h5
